import pprint
import sys
import winsound  # to make a beep when learning is done

import cv2
import gym
import highway_env
#from Python.highway-env-master import setup.py
import highway_env
import numpy as np
import torch
from stable_baselines3 import DDPG
from stable_baselines3 import DDPG
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv

# Train a DDPG agent on highway-env's roundabout scenario and save it to disk.
# Use the explicit versioned id registered by highway_env.
env = gym.make("roundabout-v0")

# SB3's DDPG only supports continuous (Box) action spaces, while roundabout-v0
# defaults to a discrete meta-action space. Switch to continuous control;
# reset() rebuilds the action/observation spaces from the updated config.
env.unwrapped.configure({"action": {"type": "ContinuousAction"}})
env.reset()

n_actions = env.action_space.shape[-1]

# Training length. Must exceed "learning_starts" below, otherwise the
# model only collects experience and never performs a gradient update.
TOTAL_TIMESTEPS = 20_000

# Hyperparameters
param = {
    "policy_kwargs": dict(net_arch=[256, 256]),
    "learning_starts": 5000,
    "buffer_size": 100000,
    "batch_size": 256,
    "gamma": 0.99,
    "tau": 0.001,
    "train_freq": 1000,
    "gradient_steps": 100,
    # NormalActionNoise takes one mean/sigma per action dimension. Scalars make
    # np.random.normal return a plain float, which SB3 versions that cast the
    # sample with `.astype` cannot handle.
    "action_noise": NormalActionNoise(
        mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)
    ),
    "verbose": 1,
    "tensorboard_log": "./logs/",
}

# Create the DDPG model. This is the only model instance, so the
# hyperparameters above are actually used for training.
model = DDPG("MlpPolicy", DummyVecEnv([lambda: env]), **param)

# Save a checkpoint to ./logs/ every 10000 steps during training.
checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./logs/', name_prefix='ddpg_roundabout')
callbacks = [checkpoint_callback]

model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=callbacks)

# Optional evaluation of the trained model (runs extra full episodes):
# mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
# print(f"Test reward: {mean_reward} +/- {std_reward}")
model.save("roundabout_ddpg/model")

print()
print("Done Learning!!")
print()





########## Load and test the saved model ##########


# To test a different saved model, change the path passed to DDPG.load:

# Reload the saved policy and watch it drive a fixed number of episodes.
model = DDPG.load("roundabout_ddpg/model")

# Number of test episodes; replace the for-loop with `while True:` to run forever.
N_TEST_EPISODES = 40

for episode in range(N_TEST_EPISODES):
    done = truncated = False
    obs, info = env.reset()
    episode_return = 0.0
    # Run one episode until the env reports termination or truncation.
    while not (done or truncated):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, truncated, info = env.step(action)
        episode_return += reward

        # NOTE(review): with the (obs, info) / 5-tuple gym API used here,
        # render() presumably only draws if a render_mode was given to
        # gym.make(); confirm the env is created with render_mode="human".
        env.render()

    print(f"Episode {episode}: return = {episode_return:.2f}")

print('DONE')


#print(env_reward())

# NOTE:
# `rewards` gives the reward broken down into separate categories,
# while `reward` combines the values in `rewards` into a single value.
# `reward` is computed from `rewards` using the environment config.
